# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Shows how to generate image conditioning from a source image with the
Amazon Titan Image Generator G1 V2 model (on demand).
"""
import base64
import io
import json
import boto3
from PIL import Image
from botocore.exceptions import ClientError
import os
import time
from src.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, COVER_PROMPT


class ImageError(Exception):
    """Error reported by the Amazon Titan Image Generator V2 model.

    The text passed to the constructor is kept on the ``message``
    attribute so handlers can print it directly.
    """

    def __init__(self, message):
        # Forward the text to Exception as well; args/str() stay the same.
        super().__init__(message)
        self.message = message


def generate_image(model_id, body):
    """
    Generate an image using Amazon Titan Image Generator V2 model on demand.

    A ``bedrock-runtime`` client is created for ``us-east-1`` with the
    credentials imported from ``src.config``, the model is invoked with a
    JSON request, and the first image in the response is decoded.

    Args:
        model_id (str): The model ID to use.
        body (str) : The request body to use (a JSON document).
    Returns:
        image_bytes (bytes): The image generated by the model.
    Raises:
        ImageError: If the response carries an ``error`` field, or if it
            contains no images.
        botocore.exceptions.ClientError: Propagated from ``invoke_model``.
    """

    bedrock = boto3.client(
        service_name="bedrock-runtime",
        region_name="us-east-1",
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )

    accept = "application/json"
    content_type = "application/json"

    response = bedrock.invoke_model(
        body=body, modelId=model_id, accept=accept, contentType=content_type
    )
    response_body = json.loads(response.get("body").read())

    # Inspect the error field BEFORE touching "images": an error response may
    # have no image list, and indexing it first would raise TypeError or
    # IndexError instead of the ImageError that callers handle.
    finish_reason = response_body.get("error")
    if finish_reason is not None:
        raise ImageError(f"Image generation error. Error is {finish_reason}")

    images = response_body.get("images")
    if not images:
        raise ImageError("Image generation error. The response contained no images.")

    # The image payload is base64 text; decode it to raw bytes.
    base64_bytes = images[0].encode("ascii")
    return base64.b64decode(base64_bytes)


def amazon_generate_cover(your_file_path):
    """
    Entrypoint for Amazon Titan Image Generator V2 example.

    Reads the image at ``your_file_path``, sends it as the condition image
    (``CANNY_EDGE`` control mode) together with ``COVER_PROMPT`` to the
    model, saves the generated 1024x1024 image next to the source file
    under a timestamped ``.png`` name, and then deletes the source file.

    Args:
        your_file_path (str): Path of the source image. The file is removed
            after the generated cover has been saved.
    Returns:
        str | None: Path of the saved cover image, or ``None`` when a
        ``ClientError`` or ``ImageError`` occurred (the error is printed).

    Other exceptions (for example, failing to open the source file) are not
    caught here and propagate to the caller.
    """
    # Defined outside the try so the success message can always reference it.
    model_id = "amazon.titan-image-generator-v2:0"

    try:
        # Read image from file and encode it as base64 string.
        with open(your_file_path, "rb") as image_file:
            input_image = base64.b64encode(image_file.read()).decode("utf8")

        body = json.dumps(
            {
                "taskType": "TEXT_IMAGE",
                "textToImageParams": {
                    "text": COVER_PROMPT,
                    "negativeText": "",
                    "conditionImage": input_image,
                    "controlMode": "CANNY_EDGE",
                },
                "imageGenerationConfig": {
                    "numberOfImages": 1,
                    "height": 1024,
                    "width": 1024,
                    "cfgScale": 8.0,
                },
            }
        )

        image_bytes = generate_image(model_id=model_id, body=body)
        image = Image.open(io.BytesIO(image_bytes))
        # The cover name is a timestamp with one-second resolution.
        cover_name = time.strftime("%Y%m%d%H%M%S") + ".png"
        temp_cover_path = os.path.join(os.path.dirname(your_file_path), cover_name)
        image.save(temp_cover_path)
        # The source image is removed only after the cover has been written.
        os.remove(your_file_path)

    except ClientError as err:
        message = err.response["Error"]["Message"]
        print("A client error occured: " + format(message), flush=True)
        return None
    except ImageError as err:
        print(err.message, flush=True)
        return None

    else:
        # Returning from inside the try body would skip this else clause, so
        # the success message is printed here and the path returned afterwards.
        print(
            f"Finished generating image with Amazon Titan Image Generator V2 model {model_id}.",
            flush=True,
        )
        return temp_cover_path


if __name__ == "__main__":
    # Script entry: generate a cover and print the resulting path (or None).
    # NOTE(review): the empty string looks like a placeholder for a real
    # source-image path — confirm before running.
    generated_cover = amazon_generate_cover("")
    print(generated_cover)
